"""
Phase 9: Sector Decomposition — Crouzet et al. (2025) Test
============================================================
Tests whether demographic investment effort (I/Y) concentrates in
services/construction vs. manufacturing, following the prediction that
aging slows consumer-facing technology diffusion (Crouzet et al. 2025)
while the demographic savings channel operates through non-tradable sectors.

Dependent variables:
  - Manufacturing VA / GDP
  - Services VA / GDP
  - Industry VA / GDP (includes manufacturing + mining + construction + utilities)
  - Agriculture VA / GDP
  - Non-manufacturing industry VA / GDP (= Industry − Manufacturing ≈ construction + mining + utilities)
  - I/Y (capital_intensity) — for comparison with sector shares

If demographics predict I/Y (+) together with services VA (+) and manufacturing VA (0 or −),
this is consistent with investment flowing to low-TFP-growth sectors.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

DATA = PROJECT_DIR.joinpath("data", "processed")
OUT_TABLES = PROJECT_DIR.joinpath("output", "tables")
OUT_TABLES.mkdir(exist_ok=True, parents=True)

# ISO3 codes of the 38 OECD member countries; defines the OECD / non-OECD split.
OECD = (
    "AUS AUT BEL CAN CHL COL CRI CZE DNK EST "
    "FIN FRA DEU GRC HUN ISL IRL ISR ITA JPN "
    "KOR LVA LTU LUX MEX NLD NZL NOR POL PRT "
    "SVK SVN ESP SWE CHE TUR GBR USA"
).split()

# Control regressors appended to every specification's demographic variables.
CONTROLS = "fiscal_bal_gdp nfa_gdp_lag rgdp_growth log_rel_opw kaopen".split()


def stars(p):
    """Return the conventional significance marker for p-value *p*.

    '***' below 0.01, '**' below 0.05, '*' below 0.1, otherwise ''.
    A NaN p-value fails every comparison and therefore yields ''.
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def run_panel_gls(df, y_var, x_vars, label, min_obs=50):
    """Fit a PanelGLS of *y_var* on *x_vars* and return a flat result dict.

    Rows with a missing value in the dependent variable, any regressor, or the
    ``iso3``/``year`` identifiers are dropped before fitting.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel expected to contain ``iso3``, ``year``, *y_var* and every name
        in *x_vars*.
    y_var : str
        Dependent-variable column.
    x_vars : sequence of str
        Regressor columns; coefficients are reported in this order.
    label : str
        Human-readable model name, stored under ``'model'`` and printed.
    min_obs : int, default 50
        Minimum number of complete rows required to fit the model.

    Returns
    -------
    dict or None
        Keys ``model``, ``dep_var``, ``n_obs``, ``n_countries``, ``r_squared``,
        ``rho`` plus ``{name}_coef`` / ``{name}_se`` / ``{name}_p`` for every
        regressor. ``None`` when a required column is absent or there are
        fewer than *min_obs* complete rows.
    """
    cols = [y_var] + list(x_vars) + ['iso3', 'year']

    # Skip (instead of crashing with KeyError) when a regressor such as
    # old_dep / youth_dep is not present in the panel.
    missing = [c for c in cols if c not in df.columns]
    if missing:
        print(f"  {label}: missing columns {missing}, skipping")
        return None

    sub = df[cols].dropna()
    if len(sub) < min_obs:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    gls = PanelGLS()
    y = sub[y_var].values
    X = sub[list(x_vars)].values
    gls.fit(y, X, sub['iso3'].values, sub['year'].values)

    result = {
        'model': label,
        'dep_var': y_var,
        'n_obs': gls.n_obs,
        'n_countries': gls.n_countries,
        'r_squared': gls.r_squared,
        'rho': gls.rho,
    }
    for i, name in enumerate(x_vars):
        result[f'{name}_coef'] = gls.beta[i]
        result[f'{name}_se'] = gls.se[i]
        result[f'{name}_p'] = gls.pvalues[i]

    print(f"  {label} (N={gls.n_obs}, C={gls.n_countries}, R²={gls.r_squared:.4f})")
    # Echo the non-control (demographic) regressors: Z_1..Z_3 in the main
    # specs, old_dep / youth_dep in the OADR spec. Controls stay in `result`.
    for name in x_vars:
        if name in CONTROLS:
            continue
        sig = stars(result[f'{name}_p'])
        print(f"    {name:10s} {result[f'{name}_coef']:10.4f} ({result[f'{name}_se']:.4f}) {sig}")

    return result


def fetch_wdi_sector_data(years=range(1980, 2024)):
    """Fetch sector value-added shares from World Bank WDI and cache them.

    Each indicator is downloaded with ``wbgapi`` (its default economy
    selection), reshaped to long ``(iso3, year, value)`` form, and the four
    series are outer-merged on ``iso3``/``year``. A derived
    ``nonmanuf_industry_va_gdp`` column (industry minus manufacturing) is
    added and the result is written to ``DATA / "sector_va_wdi.csv"``.

    Parameters
    ----------
    years : iterable of int, default range(1980, 2024)
        Years requested from the API (the default covers 1980–2023).

    Returns
    -------
    pandas.DataFrame
        Columns ``iso3``, ``year``, ``manuf_va_gdp``, ``services_va_gdp``,
        ``industry_va_gdp``, ``agric_va_gdp``, ``nonmanuf_industry_va_gdp``.
    """
    import wbgapi as wb

    indicators = {
        'NV.IND.MANF.ZS': 'manuf_va_gdp',
        'NV.SRV.TOTL.ZS': 'services_va_gdp',
        'NV.IND.TOTL.ZS': 'industry_va_gdp',
        'NV.AGR.TOTL.ZS': 'agric_va_gdp',
    }

    frames = []
    for code, name in indicators.items():
        print(f"  Fetching {name} ({code})...")
        df = wb.data.DataFrame(code, time=years, labels=False)
        df.index.name = 'iso3'
        df = df.stack().reset_index()
        df.columns = ['iso3', 'year', name]
        # Time keys normally arrive as 'YR1980'; astype(str) also tolerates
        # numeric keys, and a literal (non-regex) replace strips the prefix.
        df['year'] = df['year'].astype(str).str.replace('YR', '', regex=False).astype(int)
        frames.append(df)

    merged = frames[0]
    for f in frames[1:]:
        merged = merged.merge(f, on=['iso3', 'year'], how='outer')

    # Non-manufacturing industry (≈ construction + mining + utilities).
    # NaN whenever either component is missing for that country-year.
    merged['nonmanuf_industry_va_gdp'] = merged['industry_va_gdp'] - merged['manuf_va_gdp']

    out_path = DATA / "sector_va_wdi.csv"
    merged.to_csv(out_path, index=False)
    print(f"  Saved: {out_path} ({len(merged)} rows)")
    return merged


def build_table(results, dep_vars_labels,
                key_vars=('Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep')):
    """Build a markdown table of sector decomposition results.

    Parameters
    ----------
    results : iterable of dict or None
        Result dicts as produced by ``run_panel_gls``; ``None`` entries are
        ignored. May be any iterable (it is materialized once).
    dep_vars_labels : list of (str, str)
        ``(column, label)`` pairs for the dependent variables. Not used in
        the rendered output; accepted for call-site compatibility.
    key_vars : sequence of str
        Regressors listed in the coefficient table, in this order, for every
        model that estimated them. The default covers both the Z₁–Z₃
        specifications and the old/youth dependency (OADR) specification, so
        OADR models are no longer silently dropped from the table.

    Returns
    -------
    str
        The markdown document as a single newline-joined string.
    """
    # Materialize so a generator can be walked by both sections below.
    results = [r for r in results if r is not None]

    lines = []
    lines.append("# Table A4: Sector Decomposition of Demographic Investment")
    lines.append("")
    lines.append("*Does demographic investment effort concentrate in services or manufacturing?*")
    lines.append("")

    # Model summary
    lines.append("## Model Summary")
    lines.append("")
    lines.append("| Model | Dep Var | Sample | N | Countries | R² | ρ |")
    lines.append("|---|---|---|---|---|---|---|")
    for r in results:
        lines.append(f"| {r['model']} | {r['dep_var']} | — | {r['n_obs']} | {r['n_countries']} | {r['r_squared']:.3f} | {r['rho']:.3f} |")

    # Key (demographic) coefficients; controls are omitted for readability.
    lines.append("")
    lines.append(f"## Key Coefficients ({', '.join(key_vars)})")
    lines.append("")
    lines.append("| Model | Variable | Coef | SE | p-value | Sig |")
    lines.append("|---|---|---|---|---|---|")
    for r in results:
        for var in key_vars:
            if f'{var}_coef' not in r:
                continue
            coef = r[f'{var}_coef']
            se = r[f'{var}_se']
            p = r[f'{var}_p']
            lines.append(f"| {r['model']} | {var} | {coef:.4f} | {se:.4f} | {p:.4f} | {stars(p)} |")

    lines.append("")
    lines.append(f"*Controls: {', '.join(CONTROLS)}*")
    lines.append("*PanelGLS with AR(1) correction.*")
    lines.append("*Sector VA data: World Bank WDI (NV.IND.MANF.ZS, NV.SRV.TOTL.ZS, NV.IND.TOTL.ZS, NV.AGR.TOTL.ZS).*")

    return '\n'.join(lines)


def _run_section(df, y_specs, x_vars, tag, sample_label):
    """Fit one PanelGLS per ``(column, label)`` in *y_specs* present in *df*.

    Each model is labelled ``"{tag}: {label} ({sample_label})"``. Returns the
    ``run_panel_gls`` results in *y_specs* order (entries may be ``None``).
    """
    return [
        run_panel_gls(df, var, x_vars, f"{tag}: {label} ({sample_label})")
        for var, label in y_specs
        if var in df.columns
    ]


def main():
    """Run the sector-decomposition regressions and write the markdown table.

    Loads ``automation_panel.csv``, loads (or fetches and caches) the WDI
    sector value-added shares, left-merges them on ``iso3``/``year`` and fits
    PanelGLS models for the full, OECD and non-OECD samples (Z₁–Z₃ plus
    controls) and an OADR variant (old/youth dependency plus controls).
    The table is written as UTF-8 to ``OUT_TABLES``.
    """
    print("=" * 60)
    print("Phase 9: Sector Decomposition")
    print("=" * 60)

    # Load base panel
    panel = pd.read_csv(DATA / "automation_panel.csv")

    # Fetch WDI sector data
    print("\n[1] Fetching WDI sector value-added data...")
    sector_path = DATA / "sector_va_wdi.csv"
    if sector_path.exists():
        print(f"  Using cached: {sector_path}")
        sector = pd.read_csv(sector_path)
    else:
        sector = fetch_wdi_sector_data()

    # Merge. If the base panel already carries a column the WDI file provides,
    # a plain merge would suffix both copies (_x/_y) and every regression on
    # that column would be silently skipped; prefer the WDI values instead.
    keys = ['iso3', 'year']
    overlap = [c for c in sector.columns if c in panel.columns and c not in keys]
    if overlap:
        print(f"  Replacing panel columns with WDI values: {overlap}")
        panel = panel.drop(columns=overlap)
    # many_to_one: a duplicated iso3-year in the sector file would otherwise
    # silently duplicate panel rows and inflate N.
    panel = panel.merge(sector, on=keys, how='left', validate='many_to_one')

    sector_vars = [
        ('manuf_va_gdp', 'Manufacturing VA (% GDP)'),
        ('services_va_gdp', 'Services VA (% GDP)'),
        ('industry_va_gdp', 'Industry VA (% GDP)'),
        ('agric_va_gdp', 'Agriculture VA (% GDP)'),
        ('nonmanuf_industry_va_gdp', 'Non-Manuf Industry VA (% GDP)'),
        ('capital_intensity', 'Investment/GDP (I/Y)'),
    ]

    # Quick coverage check
    print("\n[2] Coverage check:")
    for var, label in sector_vars:
        if var in panel.columns:
            valid = panel[var].notna().sum()
            c = panel.loc[panel[var].notna(), 'iso3'].nunique()
            print(f"  {label:40s} {valid:>6} obs, {c:>3} countries")

    x_vars = ['Z_1', 'Z_2', 'Z_3'] + CONTROLS
    results = []
    is_oecd = panel['iso3'].isin(OECD)

    # --- Section A: Full sample ---
    print("\n[3] Full Sample Regressions")
    print("-" * 40)
    results += _run_section(panel, sector_vars, x_vars, 'A', 'Full')

    # --- Section B: OECD ---
    print("\n[4] OECD Subsample")
    print("-" * 40)
    results += _run_section(panel[is_oecd], sector_vars, x_vars, 'B', 'OECD')

    # --- Section C: Non-OECD ---
    print("\n[5] Non-OECD Subsample")
    print("-" * 40)
    results += _run_section(panel[~is_oecd], sector_vars, x_vars, 'C', 'Non-OECD')

    # --- Section D: OADR decomposition ---
    print("\n[6] OADR Decomposition (Full Sample)")
    print("-" * 40)
    x_oadr = ['old_dep', 'youth_dep'] + CONTROLS
    # Sector VA shares only (I/Y excluded), selected by name rather than position.
    va_vars = [(v, l) for v, l in sector_vars if v != 'capital_intensity']
    missing = [c for c in x_oadr if c not in panel.columns]
    if missing:
        print(f"  Skipping OADR models: missing columns {missing}")
    else:
        results += _run_section(panel, va_vars, x_oadr, 'D', 'OADR')

    # Build and save table (explicit UTF-8: the markdown contains Z₁, ², ρ, —)
    results_clean = [r for r in results if r is not None]
    md = build_table(results_clean, sector_vars)
    out_path = OUT_TABLES / "phase9_sector_decomposition.md"
    out_path.write_text(md, encoding='utf-8')
    print(f"\nSaved: {out_path}")

    # Print interpretation summary: the headline demographic coefficient is
    # Z₁ for sections A–C and old_dep for the OADR models in section D.
    print("\n" + "=" * 60)
    print("INTERPRETATION SUMMARY")
    print("=" * 60)
    for r in results_clean:
        for var, sym in (('Z_1', 'Z₁'), ('old_dep', 'old_dep')):
            if f'{var}_coef' in r:
                coef = r[f'{var}_coef']
                p = r[f'{var}_p']
                print(f"  {r['model']:50s} {sym} = {coef:8.2f} (p={p:.3f}) {stars(p)}")
                break


# Run the full pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
